{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example Q4 - Quantum Generative Adversarial Network\n",
    "\n",
    "This demo constructs a Quantum Generative Adversarial Network (QGAN) ([Lloyd and Weedbrook (2018)](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.121.040502), [Dallaire-Demers and Killoran (2018)](https://journals.aps.org/pra/abstract/10.1103/PhysRevA.98.012324)) using two subcircuits, a *generator* and a *discriminator*. The generator attempts to produce synthetic quantum data that matches a pattern of \"real\" data, while the discriminator tries to distinguish real data from fake data. The gradient of the discriminator's output provides a training signal that the generator uses to improve its fake data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports\n",
    "\n",
    "As usual, we import PennyLane, the PennyLane-provided version of NumPy, and an optimizer. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pennylane as qml\n",
    "from pennylane import numpy as np\n",
    "from pennylane.optimize import GradientDescentOptimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also declare a 3-qubit device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev = qml.device('default.qubit', wires=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classical and quantum nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In classical GANs, the starting point is to draw samples either from some \"real data\" distribution, or from the generator, and feed them to the discriminator. In this QGAN example, we will use a quantum circuit to generate the real data.\n",
    "\n",
    "For this simple example, our real data will be a qubit that has been rotated (from the starting state $\\left|0\\right\\rangle$) to some arbitrary, but fixed, state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def real(phi, theta, omega):\n",
    "    qml.Rot(phi, theta, omega, wires=0)"
   ]
  },
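  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a short worked form of this state, assuming PennyLane's convention that `qml.Rot(phi, theta, omega)` applies $R_Z(\\omega) R_Y(\\theta) R_Z(\\phi)$ (the symbol $\\left|\\psi_{\\mathrm{real}}\\right\\rangle$ is notation introduced here for illustration):\n",
    "\n",
    "$$\\left|\\psi_{\\mathrm{real}}\\right\\rangle = R_Z(\\omega) R_Y(\\theta) R_Z(\\phi) \\left|0\\right\\rangle = e^{-i(\\phi+\\omega)/2} \\left( \\cos\\frac{\\theta}{2} \\left|0\\right\\rangle + e^{i\\omega} \\sin\\frac{\\theta}{2} \\left|1\\right\\rangle \\right).$$\n",
    "\n",
    "Up to a global phase, the state therefore depends only on $\\theta$ and $\\omega$. Because the circuit starts in $\\left|0\\right\\rangle$, the first rotation angle $\\phi$ contributes only to that phase."
   ]
  },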
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the generator and discriminator, we will choose the same basic circuit structure, but acting on different wires. \n",
    "\n",
    "Both the real data circuit and the generator will output on wire 0, which will be connected as an input to the discriminator. Wire 1 is provided as a workspace for the generator, while the discriminator's output will be on wire 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generator(w):\n",
    "    qml.RX(w[0], wires=0)\n",
    "    qml.RX(w[1], wires=1)\n",
    "    qml.RY(w[2], wires=0)\n",
    "    qml.RY(w[3], wires=1)\n",
    "    qml.RZ(w[4], wires=0)\n",
    "    qml.RZ(w[5], wires=1)\n",
    "    qml.CNOT(wires=[0,1])\n",
    "    qml.RX(w[6], wires=0)\n",
    "    qml.RY(w[7], wires=0)\n",
    "    qml.RZ(w[8], wires=0)\n",
    "    \n",
    "def discriminator(w):\n",
    "    qml.RX(w[0], wires=0)\n",
    "    qml.RX(w[1], wires=2)\n",
    "    qml.RY(w[2], wires=0)\n",
    "    qml.RY(w[3], wires=2)\n",
    "    qml.RZ(w[4], wires=0)\n",
    "    qml.RZ(w[5], wires=2)\n",
    "    qml.CNOT(wires=[1,2])\n",
    "    qml.RX(w[6], wires=2)\n",
    "    qml.RY(w[7], wires=2)\n",
    "    qml.RZ(w[8], wires=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create two QNodes: one where the real data source is wired up to the discriminator, and one where the generator is connected to the discriminator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@qml.qnode(dev)\n",
    "def real_disc_circuit(phi, theta, omega, disc_weights):\n",
    "    # real data source followed by the discriminator\n",
    "    real(phi, theta, omega)\n",
    "    discriminator(disc_weights)\n",
    "    # the discriminator's verdict is the Pauli-Z expectation value on wire 2\n",
    "    return qml.expval(qml.PauliZ(2))\n",
    "\n",
    "@qml.qnode(dev)\n",
    "def gen_disc_circuit(gen_weights, disc_weights):\n",
    "    # generator followed by the discriminator\n",
    "    generator(gen_weights)\n",
    "    discriminator(disc_weights)\n",
    "    return qml.expval(qml.PauliZ(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two ingredients to the cost here. The first is the probability that the discriminator correctly classifies real data as real. The second is the probability that the discriminator mistakenly classifies fake data (i.e., a state prepared by the generator) as real.\n",
    "\n",
    "The discriminator's objective is to maximize the probability of correctly classifying real data, while minimizing the probability of mistakenly classifying fake data as real.\n",
    "\n",
    "The generator's objective is to maximize the probability that the discriminator accepts fake data as real."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prob_real_true(disc_weights):\n",
    "    # phi, theta, omega are the fixed real-data angles, set in the Optimization section\n",
    "    true_disc_output = real_disc_circuit(phi, theta, omega, disc_weights)\n",
    "    # convert the expectation value in [-1, 1] to a probability in [0, 1]\n",
    "    return (true_disc_output + 1) / 2\n",
    "\n",
    "def prob_fake_true(gen_weights, disc_weights):\n",
    "    fake_disc_output = gen_disc_circuit(gen_weights, disc_weights)\n",
    "    # convert the expectation value in [-1, 1] to a probability in [0, 1]\n",
    "    # (the generator wants to maximize this, the discriminator to minimize it)\n",
    "    return (fake_disc_output + 1) / 2\n",
    "\n",
    "def disc_cost(disc_weights):\n",
    "    # gen_weights is read from the global scope and held fixed\n",
    "    return prob_fake_true(gen_weights, disc_weights) - prob_real_true(disc_weights)\n",
    "\n",
    "def gen_cost(gen_weights):\n",
    "    # disc_weights is read from the global scope and held fixed\n",
    "    return -prob_fake_true(gen_weights, disc_weights)"
   ]
  },
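  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the cost functions above compute (the symbols $P_{\\mathrm{real}}$, $P_{\\mathrm{fake}}$, $C_D$ and $C_G$ are notation introduced here for illustration), each probability is obtained by rescaling a Pauli-Z expectation value on wire 2 from the interval $[-1, 1]$ to $[0, 1]$:\n",
    "\n",
    "$$P_{\\mathrm{real}} = \\frac{\\langle Z_2 \\rangle_{\\mathrm{real}} + 1}{2}, \\qquad P_{\\mathrm{fake}} = \\frac{\\langle Z_2 \\rangle_{\\mathrm{fake}} + 1}{2}.$$\n",
    "\n",
    "With these, the two costs read\n",
    "\n",
    "$$C_D = P_{\\mathrm{fake}} - P_{\\mathrm{real}}, \\qquad C_G = -P_{\\mathrm{fake}}.$$\n",
    "\n",
    "Both costs are minimized by gradient descent, so $C_D$ can range from $-1$ (a perfect discriminator) to $+1$, and $C_G$ reaches its minimum of $-1$ when the discriminator always accepts the generator's output as real."
   ]
  },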
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We initialize the fixed angles of the \"real data\" circuit, as well as the initial parameters for both generator and discriminator. These are chosen so that the generator initially prepares a state on wire 0 that is very close to the $\\left| 1 \\right\\rangle$ state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "phi = np.pi / 6\n",
    "theta = np.pi / 2\n",
    "omega = np.pi / 7\n",
    "np.random.seed(0)\n",
    "eps = 1e-2\n",
    "gen_weights = np.array([np.pi] + [0] * 8) + np.random.normal(scale=eps, size=[9])\n",
    "disc_weights = np.random.normal(size=[9])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by creating the optimizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = GradientDescentOptimizer(0.1)"
   ]
  },
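  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optimizer created above performs plain gradient descent with step size $\\eta = 0.1$. As a sketch (with $w$ denoting whichever weight vector is being trained and $C$ the corresponding cost, both notation introduced here for illustration), each call to `opt.step` applies the update\n",
    "\n",
    "$$w^{(t+1)} = w^{(t)} - \\eta \\, \\nabla_w C\\left(w^{(t)}\\right).$$\n",
    "\n",
    "Only the weights passed to `opt.step` are updated; the other network's weights enter the cost as fixed values."
   ]
  },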
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first stage of training, we optimize the discriminator while keeping the generator parameters fixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1: cost = -0.10942017805789117\n",
      "Step 6: cost = -0.389988422649031\n",
      "Step 11: cost = -0.666019117581563\n",
      "Step 16: cost = -0.8550839212078474\n",
      "Step 21: cost = -0.9454459581664489\n",
      "Step 26: cost = -0.9805878247866409\n",
      "Step 31: cost = -0.9931371328342753\n",
      "Step 36: cost = -0.9974896764916594\n",
      "Step 41: cost = -0.9989863506630716\n",
      "Step 46: cost = -0.9995000463932017\n"
     ]
    }
   ],
   "source": [
    "for it in range(50):\n",
    "    disc_weights = opt.step(disc_cost, disc_weights) \n",
    "    cost = disc_cost(disc_weights)\n",
    "    if it % 5 == 0:\n",
    "        print(\"Step {}: cost = {}\".format(it+1, cost))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the discriminator's optimum, the probability for the discriminator to correctly classify the real data should be close to one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9998971951842262"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob_real_true(disc_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, we check how the discriminator classifies the generator's (still unoptimized) fake data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.00024278396180005268"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob_fake_true(gen_weights, disc_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the adversarial game, we now have to train the generator to better fool the discriminator. (We can continue training the two models in an alternating fashion until we reach the optimum of the two-player adversarial game.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: cost = 0.00026646913829919683\n",
      "Step 5: cost = 0.0004266200858935587\n",
      "Step 10: cost = 0.0006872486146978218\n",
      "Step 15: cost = 0.0011111626380132522\n",
      "Step 20: cost = 0.0018000510248328272\n",
      "Step 25: cost = 0.002917930412544567\n",
      "Step 30: cost = 0.004727717539774079\n",
      "Step 35: cost = 0.007646628881031681\n",
      "Step 40: cost = 0.012325866735736102\n",
      "Step 45: cost = 0.019754518934526843\n",
      "Step 50: cost = 0.03136834673567135\n",
      "Step 55: cost = 0.049097345993078356\n",
      "Step 60: cost = 0.07520378135265471\n",
      "Step 65: cost = 0.11169015288702144\n",
      "Step 70: cost = 0.15917286333740993\n",
      "Step 75: cost = 0.21566031343947306\n",
      "Step 80: cost = 0.2763735721045272\n",
      "Step 85: cost = 0.3354169186527467\n",
      "Step 90: cost = 0.38835012669286345\n",
      "Step 95: cost = 0.43371772120148666\n",
      "Step 100: cost = 0.47284901883928354\n",
      "Step 105: cost = 0.5087778323625858\n",
      "Step 110: cost = 0.5451977336157693\n",
      "Step 115: cost = 0.5856632916397928\n",
      "Step 120: cost = 0.6327897835085264\n",
      "Step 125: cost = 0.6872469221106684\n",
      "Step 130: cost = 0.7468453348018502\n",
      "Step 135: cost = 0.8066413637587722\n",
      "Step 140: cost = 0.8607353038575882\n",
      "Step 145: cost = 0.9048413990478879\n",
      "Step 150: cost = 0.9376687441862386\n",
      "Step 155: cost = 0.9604104860258109\n",
      "Step 160: cost = 0.975371147847835\n",
      "Step 165: cost = 0.9848746701785995\n",
      "Step 170: cost = 0.9907765008479259\n",
      "Step 175: cost = 0.9943898535953373\n",
      "Step 180: cost = 0.9965827747752589\n",
      "Step 185: cost = 0.9979065439048096\n",
      "Step 190: cost = 0.9987030608439529\n",
      "Step 195: cost = 0.9991813935606384\n"
     ]
    }
   ],
   "source": [
    "for it in range(200):\n",
    "    gen_weights = opt.step(gen_cost, gen_weights)\n",
    "    cost = -gen_cost(gen_weights)\n",
    "    if it % 5 == 0:\n",
    "        print(\"Step {}: cost = {}\".format(it, cost))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the optimum of the generator, the probability for the discriminator to be fooled should be close to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prob_fake_true(gen_weights, disc_weights)"
   ]
  },
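  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the joint optimum, the discriminator's cost will be close to zero, because the discriminator assigns nearly the same probability of being real to both real and fake data."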
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.0004751627422100446"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disc_cost(disc_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generator has successfully learned to simulate the real data well enough to fool the discriminator."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
